{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5191087559849992\n",
      "254386\n",
      "228947 25439\n",
      "0.5188362372077381 0.5215613821298007\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import itertools\n",
    "import pprint\n",
    "import pickle\n",
    "\n",
    "# Labelled question pairs: report the overall label mean and the row count.\n",
    "ori_train_csv = pd.read_csv('./files/train.csv')\n",
    "print(ori_train_csv['label'].mean())\n",
    "print(len(ori_train_csv))\n",
    "\n",
    "# Positional 90/10 split in file order (no shuffling): the leading rows go\n",
    "# to train_csv and the trailing rows to val_csv.\n",
    "thres = int(0.9 * len(ori_train_csv))\n",
    "train_csv, val_csv = ori_train_csv.iloc[:thres], ori_train_csv.iloc[thres:]\n",
    "print(len(train_csv), len(val_csv))\n",
    "print(train_csv['label'].mean(), val_csv['label'].mean())\n",
    "\n",
    "test_csv = pd.read_csv('./files/test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Avg Word Len: 6.2 | Max Word Len: 39\n",
      "Avg Char Len: 10.3 | Max Char Len: 58\n",
      "{'Q000000': ['W05733', 'W05284', 'W09158', 'W14968', 'W07863']}\n",
      "{'Q000000': ['L1128', 'L1861', 'L2218', 'L1796', 'L1055', 'L0847', 'L2927']}\n"
     ]
    }
   ],
   "source": [
    "def glance(d, n=1):\n",
    "    return dict(itertools.islice(d.items(), 1))\n",
    "\n",
    "\n",
    "def fn(path):\n",
    "    _q2w, _q2c = {}, {}\n",
    "    _w_lens, _c_lens = [], []\n",
    "\n",
    "    with open(path) as f:\n",
    "        next(f)\n",
    "        for line in f:\n",
    "            l_split = line.split(',')\n",
    "            qid, words, chars = l_split\n",
    "\n",
    "            words_sp = words.split()\n",
    "            chars_sp = chars.split()\n",
    "\n",
    "            _q2w[qid] = words_sp\n",
    "            _q2c[qid] = chars_sp\n",
    "\n",
    "            _w_lens.append(len(words_sp))\n",
    "            _c_lens.append(len(chars_sp))\n",
    "            \n",
    "    return _q2w, _q2c, _w_lens, _c_lens\n",
    "\n",
    "\n",
    "# Parse the raw question file and summarise the sequence lengths.\n",
    "_q2w, _q2c, _w_lens, _c_lens = fn('./files/question.csv')\n",
    "\n",
    "print(\"Avg Word Len: %.1f | Max Word Len: %d\"\n",
    "      % (sum(_w_lens) / len(_w_lens), max(_w_lens)))\n",
    "print(\"Avg Char Len: %.1f | Max Char Len: %d\"\n",
    "      % (sum(_c_lens) / len(_c_lens), max(_c_lens)))\n",
    "\n",
    "# Both lookups are keyed by question id, so their sizes must agree.\n",
    "assert len(_q2w) == len(_q2c), (\n",
    "    \"len(q2w): %d, len(q2c): %d\" % (len(_q2w), len(_q2c)))\n",
    "\n",
    "# Show one entry of each mapping.\n",
    "for _lookup in (_q2w, _q2c):\n",
    "    pprint.pprint(glance(_lookup))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the 'W'/'L' token prefixes (plus one leading zero where present)\n",
    "# from the whole file and store the result as a separate cleaned copy.\n",
    "with open('./files/question.csv') as f:\n",
    "    st = f.read()\n",
    "\n",
    "# Order matters: the zero-padded prefixes have to be replaced before the\n",
    "# bare letters, otherwise the leading zero would be left behind.\n",
    "for prefix in ('W0', 'L0', 'W', 'L'):\n",
    "    st = st.replace(prefix, '')\n",
    "\n",
    "with open('./question_cleaned.csv', 'w') as f:\n",
    "    f.write(st)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Avg Word Len: 6.2 | Max Word Len: 39\n",
      "Avg Char Len: 10.3 | Max Char Len: 58\n"
     ]
    }
   ],
   "source": [
    "def save_obj(obj, path):\n",
    "    with open(path, 'wb') as f:\n",
    "        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)\n",
    "\n",
    "def load_obj(path):\n",
    "    with open(path, 'rb') as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "# Re-parse the prefix-stripped copy of the question file.\n",
    "q2w, q2c, w_lens, c_lens = fn('./question_cleaned.csv')\n",
    "\n",
    "print(\"Avg Word Len: %.1f | Max Word Len: %d\"\n",
    "      % (sum(w_lens) / len(w_lens), max(w_lens)))\n",
    "print(\"Avg Char Len: %.1f | Max Char Len: %d\"\n",
    "      % (sum(c_lens) / len(c_lens), max(c_lens)))\n",
    "\n",
    "# Both lookups are keyed by question id, so their sizes must agree.\n",
    "assert len(q2w) == len(q2c), (\n",
    "    \"len(q2w): %d, len(q2c): %d\" % (len(q2w), len(q2c)))\n",
    "\n",
    "# Persist both lookups as pickle files.\n",
    "for _obj, _dest in ((q2w, './q2w.pkl'), (q2c, './q2c.pkl')):\n",
    "    save_obj(_obj, _dest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'Q000000': ['5733', '5284', '9158', '14968', '7863']}\n",
      "{'Q000000': ['1128', '1861', '2218', '1796', '1055', '847', '2927']}\n"
     ]
    }
   ],
   "source": [
    "# Reload the pickled lookups and show one entry of each as a sanity check.\n",
    "q2w, q2c = (load_obj(_src) for _src in ('./q2w.pkl', './q2c.pkl'))\n",
    "\n",
    "for _lookup in (q2w, q2c):\n",
    "    pprint.pprint(glance(_lookup))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(20892, 300)\n"
     ]
    }
   ],
   "source": [
    "# Each line of the embedding file is '<token> v1 v2 ...'; the leading token\n",
    "# is dropped and only the numeric values are kept.\n",
    "# NOTE(review): presumably row i holds the vector of word id i - confirm\n",
    "# against the file before relying on it.\n",
    "embed_vals = []\n",
    "with open('./files/word_embed.txt') as f:\n",
    "    for line in f:\n",
    "        line_sp = line.split()\n",
    "        if not line_sp:\n",
    "            # A blank line would add an empty row and make the matrix ragged.\n",
    "            continue\n",
    "        embed_vals.append([float(num) for num in line_sp[1:]])\n",
    "embed_vals = np.asarray(embed_vals, dtype=np.float32)\n",
    "\n",
    "# The padding id is the first index past the loaded vectors; its all-zero\n",
    "# row is appended below so that the id is a valid row of the saved matrix.\n",
    "PAD_INT = embed_vals.shape[0]\n",
    "# Take the row width from the loaded data instead of hard-coding 300.\n",
    "zeros = np.zeros((1, embed_vals.shape[1]), dtype=np.float32)\n",
    "embed_vals = np.concatenate([embed_vals, zeros])\n",
    "print(embed_vals.shape)\n",
    "np.save('./word_embed.npy', embed_vals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████| 25439/25439 [00:20<00:00, 1237.36it/s]\n"
     ]
    }
   ],
   "source": [
    "w_max_len = 20\n",
    "\n",
    "def fn1(str_li, int_li):\n",
    "    for i, s in enumerate(str_li[:w_max_len]):\n",
    "        int_li[i] = int(str_li[i])\n",
    "\n",
    "def train_fn(csv, path):\n",
    "    \"\"\"Write labelled question pairs to a TFRecord file.\n",
    "\n",
    "    Every row of ``csv.values`` is unpacked as ``(label, q1_id, q2_id)``.\n",
    "    Both questions are looked up in ``q2w`` and turned into lists of\n",
    "    ``w_max_len`` integer ids (unused slots keep ``PAD_INT``), which are\n",
    "    stored as the int64 features ``input1`` and ``input2`` next to ``label``.\n",
    "\n",
    "    Args:\n",
    "        csv: Table (e.g. a pandas DataFrame) whose ``.values`` rows unpack\n",
    "            as ``(label, q1_id, q2_id)``.\n",
    "        path: Destination path of the TFRecord file.\n",
    "    \"\"\"\n",
    "    def _int64_feature(values):\n",
    "        # Wrap a list of ints in the protobuf feature type.\n",
    "        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))\n",
    "\n",
    "    # The context manager closes the writer on exit, including when a row\n",
    "    # raises, instead of leaving the file handle open.\n",
    "    with tf.python_io.TFRecordWriter(path) as writer:\n",
    "        for arr_line in tqdm(csv.values, total=len(csv), ncols=70):\n",
    "            q1_id_int, q2_id_int = [PAD_INT]*w_max_len, [PAD_INT]*w_max_len\n",
    "\n",
    "            label, q1_id, q2_id = arr_line\n",
    "            fn1(q2w[q1_id], q1_id_int)\n",
    "            fn1(q2w[q2_id], q2_id_int)\n",
    "\n",
    "            example = tf.train.Example(\n",
    "                features=tf.train.Features(\n",
    "                    feature={\n",
    "                        'input1': _int64_feature(q1_id_int),\n",
    "                        'input2': _int64_feature(q2_id_int),\n",
    "                        'label': _int64_feature([label]),\n",
    "                    }))\n",
    "            writer.write(example.SerializeToString())\n",
    "\n",
    "def test_fn(csv, path):\n",
    "    \"\"\"Write unlabelled question pairs to a TFRecord file.\n",
    "\n",
    "    Every row of ``csv.values`` is unpacked as ``(q1_id, q2_id)``. Both\n",
    "    questions are looked up in ``q2w`` and turned into lists of\n",
    "    ``w_max_len`` integer ids (unused slots keep ``PAD_INT``), which are\n",
    "    stored as the int64 features ``input1`` and ``input2``.\n",
    "\n",
    "    Args:\n",
    "        csv: Table (e.g. a pandas DataFrame) whose ``.values`` rows unpack\n",
    "            as ``(q1_id, q2_id)``.\n",
    "        path: Destination path of the TFRecord file.\n",
    "    \"\"\"\n",
    "    def _int64_feature(values):\n",
    "        # Wrap a list of ints in the protobuf feature type.\n",
    "        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))\n",
    "\n",
    "    # The context manager closes the writer on exit, including when a row\n",
    "    # raises, instead of leaving the file handle open.\n",
    "    with tf.python_io.TFRecordWriter(path) as writer:\n",
    "        for arr_line in tqdm(csv.values, total=len(csv), ncols=70):\n",
    "            q1_id_int, q2_id_int = [PAD_INT]*w_max_len, [PAD_INT]*w_max_len\n",
    "\n",
    "            q1_id, q2_id = arr_line\n",
    "            fn1(q2w[q1_id], q1_id_int)\n",
    "            fn1(q2w[q2_id], q2_id_int)\n",
    "\n",
    "            example = tf.train.Example(\n",
    "                features=tf.train.Features(\n",
    "                    feature={\n",
    "                        'input1': _int64_feature(q1_id_int),\n",
    "                        'input2': _int64_feature(q2_id_int),\n",
    "                    }))\n",
    "            writer.write(example.SerializeToString())\n",
    "        \n",
    "# Only the validation split is serialised by this cell; the train and test\n",
    "# exports below are left commented out.\n",
    "#train_fn(train_csv, './train_word.tfrecord')\n",
    "#test_fn(test_csv, './test_word.tfrecord')\n",
    "train_fn(val_csv, './val_word.tfrecord')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
